common.py
import io

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

class classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 3, 1)
        self.conv2 = nn.Conv2d(6, 16, 3, 1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(54 * 54 * 16, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 2)  # two output classes: CAT and DOG

    def forward(self, X):
        X = self.pool(F.relu(self.conv1(X)))
        X = self.pool(F.relu(self.conv2(X)))
        X = X.view(-1, 54 * 54 * 16)  # flatten the 16 x 54 x 54 feature map
        X = F.relu(self.fc1(X))
        X = F.relu(self.fc2(X))
        X = self.fc3(X)
        return F.log_softmax(X, dim=1)
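
# Shape check justifying the 54 * 54 * 16 input size of fc1, for the
# 224x224 RGB input produced by get_tensor() below:
#   conv1 (3x3, stride 1): 224 -> 222, pool (2x2): 222 -> 111
#   conv2 (3x3, stride 1): 111 -> 109, pool (2x2): 109 -> 54 (floor)
# leaving a 16-channel 54x54 feature map before the flatten.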
class_names = ['CAT', 'DOG']

def get_model():
    checkpoint = './model_state.pth'
    model = classifier()
    # Load on CPU so no GPU is needed at inference time; strict=False
    # tolerates missing or unexpected keys in the checkpoint.
    model.load_state_dict(torch.load(checkpoint, map_location='cpu'), strict=False)
    model.eval()
    return model
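
# A checkpoint compatible with get_model() would be produced by a separate
# training script (not part of this file), roughly along the lines of:
#     torch.save(trained_model.state_dict(), './model_state.pth')
# where trained_model is a trained instance of the classifier above.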

def get_tensor(image_bytes):
    # Standard ImageNet normalization statistics (mean, std per channel).
    my_transforms = transforms.Compose([
        transforms.Resize(254),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225],
        ),
    ])
    # convert('RGB') guards against grayscale or RGBA uploads, which would
    # otherwise break the (1, 3, 224, 224) shape the model expects.
    image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    return my_transforms(image).unsqueeze(0)  # add the batch dimension

def get_output(image_bytes):
    with torch.no_grad():
        model = get_model()
        tensor = get_tensor(image_bytes)
        prediction = model(tensor).argmax()
    return class_names[prediction.item()]

if __name__ == '__main__':
    # Quick smoke test. Note that get_output() expects the raw file bytes,
    # not a PIL Image object.
    with open('./static/10000.jpg', 'rb') as f:
        print(get_output(image_bytes=f.read()))
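
# Since get_output() takes raw bytes, it plugs directly into a web handler.
# Hypothetical sketch (assumes a Flask-style upload endpoint, not defined here):
#     file = request.files['file']
#     label = get_output(file.read())  # 'CAT' or 'DOG'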