32 lines
851 B
Python
32 lines
851 B
Python
import os
import pickle

import cv2
import matplotlib.pyplot as plt
import numpy as np
from skimage.feature import hog
from sklearn import svm
from sklearn import tree
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

from implement import train_and_save_model, save_data, get_data, load_data
|
|
|
|
# Directory holding the cached arrays produced from the raw training images.
DATA_DIR = "Data"

# os.path.join keeps these portable; the original hard-coded Windows "\\"
# separators, which on POSIX produced a single file literally named
# "Data\train_img_data.npy" in the working directory.
img_data_path = os.path.join(DATA_DIR, "train_img_data.npy")
marks_data_path = os.path.join(DATA_DIR, "train_marks.npy")
train_baza_path = os.path.join("Baza", "Pet1000")

# Set to True to force re-reading the raw images and regenerating the cache
# even when cached files already exist (replaces the old `if False:` toggle).
REBUILD_CACHE = False


def cache_files_missing(*paths):
    """Return True if any of the given cache file paths is not an existing file."""
    return not all(os.path.isfile(path) for path in paths)


def main(*, rebuild_cache=REBUILD_CACHE):
    """Load (or rebuild) the training data cache, then train and save the model.

    When ``rebuild_cache`` is true, or when either cache file is missing, the
    raw images under ``train_baza_path`` are read with ``get_data`` and cached
    via ``save_data``; otherwise the cached arrays are read with ``load_data``.
    The resulting data is passed to ``train_and_save_model``.
    """
    print("Train")

    print("loading data")
    if rebuild_cache or cache_files_missing(img_data_path, marks_data_path):
        img_data, marks = get_data(train_baza_path)
        # NOTE(review): unclear whether save_data creates missing directories;
        # create the cache directory up front so saving cannot fail on it.
        os.makedirs(DATA_DIR, exist_ok=True)
        save_data(img_data, marks, img_data_path, marks_data_path)
    else:
        img_data, marks = load_data(img_data_path, marks_data_path)

    print("Training")
    train_and_save_model(img_data, marks)

    print("End")


if __name__ == "__main__":
    main()
|