forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathCSession.cs
More file actions
97 lines (83 loc) · 2.88 KB
/
CSession.cs
File metadata and controls
97 lines (83 loc) · 2.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow;
using Tensorflow.Util;
namespace TensorFlowNET.UnitTest
{
/// <summary>
/// Test helper that bundles a session handle with the inputs, outputs and
/// targets passed to <c>c_api.TF_SessionRun</c>.
/// tensorflow\c\c_test_util.cc
/// TEST(CAPI, Session)
/// </summary>
public class CSession
{
    // Handle passed to c_api.TF_SessionRun; obtained from the Session created in the constructor.
    private readonly IntPtr session_;
    // inputs_[i] is fed with input_values_[i]; both lists are filled together by SetInputs.
    private readonly List<TF_Output> inputs_ = new List<TF_Output>();
    private readonly List<Tensor> input_values_ = new List<Tensor>();
    // outputs_[i] is fetched into output_values_[i] by Run.
    private readonly List<TF_Output> outputs_ = new List<TF_Output>();
    private readonly List<Tensor> output_values_ = new List<Tensor>();
    // Target operations for Run; nothing in this class adds to the list.
    private readonly List<IntPtr> targets_ = new List<IntPtr>();

    /// <summary>
    /// Creates a session for <paramref name="graph"/> with
    /// <c>InterOpParallelismThreads</c> set to 4.
    /// </summary>
    /// <param name="graph">Graph the session is created for.</param>
    /// <param name="s">Status object handed to the session constructor.</param>
    /// <param name="user_XLA">Currently unused.</param>
    public CSession(Graph graph, Status s, bool user_XLA = false)
    {
        // Session creation is serialized on the shared process-wide lock.
        lock (Locks.ProcessWide)
        {
            var opts = new SessionOptions();
            opts.SetConfig(new ConfigProto {InterOpParallelismThreads = 4});
            session_ = new Session(graph, opts, s);
        }
    }

    /// <summary>
    /// Replaces the current feeds. Each key is fed through its output index 0
    /// with the associated tensor.
    /// </summary>
    /// <param name="inputs">Operations to feed, mapped to the tensors fed into them.</param>
    public void SetInputs(Dictionary<Operation, Tensor> inputs)
    {
        DeleteInputValues();
        inputs_.Clear();
        foreach (var input in inputs)
        {
            inputs_.Add(new TF_Output(input.Key, 0));
            input_values_.Add(input.Value);
        }
    }

    /// <summary>
    /// Drops the references to the fed tensors.
    /// </summary>
    private void DeleteInputValues()
    {
        //clearing is enough as they will be disposed by the GC unless they are referenced else-where.
        input_values_.Clear();
    }

    /// <summary>
    /// Replaces the current fetches and reserves one empty result slot per fetch.
    /// </summary>
    /// <param name="outputs">Outputs to fetch on the next <see cref="Run"/>.</param>
    public void SetOutputs(TF_Output[] outputs)
    {
        ResetOutputValues();
        outputs_.Clear();
        foreach (var output in outputs)
        {
            outputs_.Add(output);
            output_values_.Add(IntPtr.Zero);
        }
    }

    /// <summary>
    /// Drops the references to previously fetched tensors.
    /// </summary>
    private void ResetOutputValues()
    {
        //clearing is enough as they will be disposed by the GC unless they are referenced else-where.
        output_values_.Clear();
    }

    /// <summary>
    /// Calls <c>c_api.TF_SessionRun</c> with the configured inputs, outputs and
    /// targets, then stores every fetched handle so it can be read back with
    /// <see cref="output_tensor"/>.
    /// </summary>
    /// <param name="s">Status passed to the native call and checked afterwards.</param>
    /// <exception cref="InvalidOperationException">
    /// The number of input values no longer matches the number of inputs
    /// (for example after <see cref="CloseAndDelete"/>).
    /// </exception>
    public unsafe void Run(Status s)
    {
        var inputs_ptr = inputs_.ToArray();
        var input_values_ptr = input_values_.Select(x => (IntPtr) x).ToArray();
        if (input_values_ptr.Length != inputs_ptr.Length)
        {
            // The native call receives a single count for both arrays, so a
            // mismatch would make it index past the end of the shorter one.
            throw new InvalidOperationException(
                "The number of input values does not match the number of inputs; call SetInputs before Run.");
        }

        var outputs_ptr = outputs_.ToArray();
        // Size the result array from the fetch list itself so the count handed
        // to the native call always matches the array it writes into.
        var output_values_ptr = new IntPtr[outputs_ptr.Length];
        var targets_ptr = targets_.ToArray();

        c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_ptr.Length,
            outputs_ptr, output_values_ptr, outputs_ptr.Length,
            targets_ptr, targets_ptr.Length,
            IntPtr.Zero, s);
        s.Check();

        // Publish every fetched handle, not only the first one.
        output_values_.Clear();
        foreach (var handle in output_values_ptr)
        {
            output_values_.Add(handle);
        }
    }

    /// <summary>
    /// Returns the handle stored for the <paramref name="i"/>-th output.
    /// </summary>
    /// <param name="i">Zero-based index into the configured outputs.</param>
    /// <returns>The stored tensor converted to <see cref="IntPtr"/>.</returns>
    public IntPtr output_tensor(int i)
    {
        return output_values_[i];
    }

    /// <summary>
    /// Releases the references to the input and output tensors held by this helper.
    /// </summary>
    /// <param name="s">Currently unused.</param>
    public void CloseAndDelete(Status s)
    {
        DeleteInputValues();
        ResetOutputValues();
    }
}
}