NimurAI commited on
Commit
5a3f083
·
verified ·
1 Parent(s): a076fe5

Upload flutter_integration_example.dart with huggingface_hub

Browse files
Files changed (1) hide show
  1. flutter_integration_example.dart +244 -0
flutter_integration_example.dart ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import 'dart:io';
2
+ import 'dart:typed_data';
3
+ import 'dart:ui' as ui;
4
+ import 'package:flutter/services.dart';
5
+ import 'package:flutter_pytorch_lite/flutter_pytorch_lite.dart';
6
+
7
+ class PlantAnomalyDetector {
8
+ Module? _module;
9
+ static const double _threshold = 0.5687; // Your threshold from training
10
+
11
+ // Normalization values from your training data
12
+ static const List<double> _mean = [0.4682, 0.4865, 0.3050];
13
+ static const List<double> _std = [0.2064, 0.1995, 0.1961];
14
+
15
+ /// Initialize the model from assets
16
+ Future<void> loadModel() async {
17
+ try {
18
+ // Load model from assets
19
+ final filePath = '${Directory.systemTemp.path}/plant_anomaly_detector.ptl';
20
+ final modelBytes = await _getBuffer('assets/models/plant_anomaly_detector.ptl');
21
+ File(filePath).writeAsBytesSync(modelBytes);
22
+
23
+ _module = await FlutterPytorchLite.load(filePath);
24
+ print('Model loaded successfully');
25
+ } catch (e) {
26
+ print('Error loading model: $e');
27
+ rethrow;
28
+ }
29
+ }
30
+
31
+ /// Get byte buffer from assets
32
+ static Future<Uint8List> _getBuffer(String assetFileName) async {
33
+ ByteData rawAssetFile = await rootBundle.load(assetFileName);
34
+ final rawBytes = rawAssetFile.buffer.asUint8List();
35
+ return rawBytes;
36
+ }
37
+
38
+ /// Normalize tensor values using training statistics
39
+ List<double> _normalize(List<double> input) {
40
+ List<double> normalized = [];
41
+ int channels = 3;
42
+ int pixelsPerChannel = input.length ~/ channels;
43
+
44
+ for (int c = 0; c < channels; c++) {
45
+ for (int i = 0; i < pixelsPerChannel; i++) {
46
+ int idx = c * pixelsPerChannel + i;
47
+ double normalizedValue = (input[idx] - _mean[c]) / _std[c];
48
+ normalized.add(normalizedValue);
49
+ }
50
+ }
51
+
52
+ return normalized;
53
+ }
54
+
55
+ /// Calculate reconstruction error (MSE) between original and reconstructed
56
+ double _calculateReconstructionError(List<double> original, List<double> reconstructed) {
57
+ if (original.length != reconstructed.length) {
58
+ throw ArgumentError('Original and reconstructed tensors must have same length');
59
+ }
60
+
61
+ double sumSquaredError = 0.0;
62
+ for (int i = 0; i < original.length; i++) {
63
+ double diff = original[i] - reconstructed[i];
64
+ sumSquaredError += diff * diff;
65
+ }
66
+
67
+ return sumSquaredError / original.length;
68
+ }
69
+
70
+ /// Detect if an image is a plant or anomaly
71
+ Future<PlantDetectionResult> detectPlant(ui.Image image) async {
72
+ if (_module == null) {
73
+ throw StateError('Model not loaded. Call loadModel() first.');
74
+ }
75
+
76
+ try {
77
+ // Convert image to tensor
78
+ final inputShape = Int64List.fromList([1, 3, 224, 224]);
79
+ Tensor inputTensor = await TensorImageUtils.imageToFloat32Tensor(
80
+ image,
81
+ width: 224,
82
+ height: 224,
83
+ );
84
+
85
+ // Get original normalized values for reconstruction error calculation
86
+ List<double> originalValues = inputTensor.dataAsFloat32List;
87
+ List<double> normalizedOriginal = _normalize(originalValues);
88
+
89
+ // Forward pass through the model
90
+ IValue input = IValue.from(inputTensor);
91
+ IValue output = await _module!.forward([input]);
92
+
93
+ // Get reconstruction
94
+ Tensor reconstructionTensor = output.toTensor();
95
+ List<double> reconstruction = reconstructionTensor.dataAsFloat32List;
96
+
97
+ // Calculate reconstruction error
98
+ double reconstructionError = _calculateReconstructionError(
99
+ normalizedOriginal,
100
+ reconstruction
101
+ );
102
+
103
+ // Determine if it's an anomaly
104
+ bool isAnomaly = reconstructionError > _threshold;
105
+ double confidence = (reconstructionError - _threshold).abs() / _threshold;
106
+
107
+ return PlantDetectionResult(
108
+ isPlant: !isAnomaly,
109
+ reconstructionError: reconstructionError,
110
+ threshold: _threshold,
111
+ confidence: confidence,
112
+ );
113
+
114
+ } catch (e) {
115
+ print('Error during inference: $e');
116
+ rethrow;
117
+ }
118
+ }
119
+
120
+ /// Dispose the model
121
+ Future<void> dispose() async {
122
+ if (_module != null) {
123
+ await _module!.destroy();
124
+ _module = null;
125
+ }
126
+ }
127
+ }
128
+
129
+ /// Result class for plant detection
130
+ class PlantDetectionResult {
131
+ final bool isPlant;
132
+ final double reconstructionError;
133
+ final double threshold;
134
+ final double confidence;
135
+
136
+ PlantDetectionResult({
137
+ required this.isPlant,
138
+ required this.reconstructionError,
139
+ required this.threshold,
140
+ required this.confidence,
141
+ });
142
+
143
+ @override
144
+ String toString() {
145
+ return 'PlantDetectionResult('
146
+ 'isPlant: $isPlant, '
147
+ 'reconstructionError: ${reconstructionError.toStringAsFixed(4)}, '
148
+ 'threshold: ${threshold.toStringAsFixed(4)}, '
149
+ 'confidence: ${(confidence * 100).toStringAsFixed(2)}%'
150
+ ')';
151
+ }
152
+ }
153
+
154
+ /// Example usage in a Flutter widget
155
+ class PlantDetectionWidget extends StatefulWidget {
156
+ @override
157
+ _PlantDetectionWidgetState createState() => _PlantDetectionWidgetState();
158
+ }
159
+
160
+ class _PlantDetectionWidgetState extends State<PlantDetectionWidget> {
161
+ final PlantAnomalyDetector _detector = PlantAnomalyDetector();
162
+ bool _isModelLoaded = false;
163
+
164
+ @override
165
+ void initState() {
166
+ super.initState();
167
+ _loadModel();
168
+ }
169
+
170
+ Future<void> _loadModel() async {
171
+ try {
172
+ await _detector.loadModel();
173
+ setState(() {
174
+ _isModelLoaded = true;
175
+ });
176
+ } catch (e) {
177
+ print('Failed to load model: $e');
178
+ }
179
+ }
180
+
181
+ Future<void> _detectFromAsset(String assetPath) async {
182
+ if (!_isModelLoaded) return;
183
+
184
+ try {
185
+ // Load image from assets
186
+ const assetImage = AssetImage('assets/images/test_plant.jpg');
187
+ final image = await TensorImageUtils.imageProviderToImage(assetImage);
188
+
189
+ // Run detection
190
+ final result = await _detector.detectPlant(image);
191
+
192
+ // Show result
193
+ print('Detection result: $result');
194
+
195
+ // You can update UI here with the result
196
+ showDialog(
197
+ context: context,
198
+ builder: (context) => AlertDialog(
199
+ title: Text(result.isPlant ? 'Plant Detected' : 'Anomaly Detected'),
200
+ content: Text(
201
+ 'Reconstruction Error: ${result.reconstructionError.toStringAsFixed(4)}\n'
202
+ 'Confidence: ${(result.confidence * 100).toStringAsFixed(2)}%'
203
+ ),
204
+ actions: [
205
+ TextButton(
206
+ onPressed: () => Navigator.pop(context),
207
+ child: Text('OK'),
208
+ ),
209
+ ],
210
+ ),
211
+ );
212
+
213
+ } catch (e) {
214
+ print('Error during detection: $e');
215
+ }
216
+ }
217
+
218
+ @override
219
+ void dispose() {
220
+ _detector.dispose();
221
+ super.dispose();
222
+ }
223
+
224
+ @override
225
+ Widget build(BuildContext context) {
226
+ return Scaffold(
227
+ appBar: AppBar(title: Text('Plant Anomaly Detection')),
228
+ body: Center(
229
+ child: Column(
230
+ mainAxisAlignment: MainAxisAlignment.center,
231
+ children: [
232
+ if (!_isModelLoaded)
233
+ CircularProgressIndicator()
234
+ else
235
+ ElevatedButton(
236
+ onPressed: () => _detectFromAsset('assets/images/test_plant.jpg'),
237
+ child: Text('Detect Plant'),
238
+ ),
239
+ ],
240
+ ),
241
+ ),
242
+ );
243
+ }
244
+ }