-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
110 lines (83 loc) · 2.39 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""
Replicating the mushroom_model behavior
Add number_input widgets for continuous values
"""
import streamlit as st
import pandas as pd
from src.mushroom_model import predict_mushroom
# # # Header # # #
st.markdown("# Mushroom Classifier")
# "# noqa" makes flake8 skip checks on this line. Specific codes must follow
# a colon ("# noqa: W291"); without the colon flake8 ignores the codes and
# treats the comment as a blanket "# noqa" that silences every check.
# W291 (trailing whitespace) is suppressed, presumably so the text below can
# keep markdown's two-trailing-space hard line breaks.
# https://flake8.pycqa.org/en/3.1.1/user/ignoring-errors.html
st.markdown("""
Enter the properties of your mushroom below.
> Note: this model was trained on a generated database.
Do not use it to classify real mushrooms!
""")  # noqa: W291
# # # Initial Observation State # # #
# Each feature maps to a one-element list so pandas builds a single-row
# frame from this dict. Insertion order here fixes the DataFrame's column
# order; the placeholder values are overwritten by the widget readings.
observation = dict([
    ('cap-diameter', [50]),
    ('stem-height', [20]),
    ('stem-width', [30]),
    ('has-ring', ['t']),
    ('cap-shape', ['c']),
])
# # # User Input Widgets # # #
# Each widget's reading replaces the matching placeholder in ``observation``.
# Streamlit lays elements out in call order, so this sequence is the form's
# on-page order.

# Ring presence: the checkbox bool becomes the 't' / 'f' flag.
has_ring = st.checkbox("The mushroom has a ring")
observation['has-ring'] = ['t' if has_ring else 'f']

# Cap shape: the readable labels (keys) double as the selectbox options,
# and the chosen label is translated to its one-letter code.
cap_shape_map = {
    'conical': 'c',
    'bell': 'b',
    'convex': 'x',
    'flat': 'f',
    'sunken': 's',
    'spherical': 'p',
    'others': 'o',
}
cap_shape = st.selectbox("Mushroom Cap Shape", options=list(cap_shape_map))
observation['cap-shape'] = [cap_shape_map[cap_shape]]

# Continuous measurements. The bounds are presumably the range seen in the
# training data (NOTE(review): confirm); note stem width is in mm.
cap_diameter = st.number_input(
    "Cap diameter (cm)",
    value=50.0,
    min_value=0.38,
    max_value=62.34,
    help="Cap diameter from 0.38 to 62.34 cm",
)
observation['cap-diameter'] = [cap_diameter]

stem_height = st.number_input(
    "Stem height (cm)",
    value=20.0,
    min_value=0.00,
    max_value=33.92,
    help="Stem height from 0 to 33.92 cm",
)
observation['stem-height'] = [stem_height]

stem_width = st.number_input(
    "Stem width (mm)",
    value=30.0,
    min_value=0.00,
    max_value=103.91,
    help="Stem width from 0 to 103.91 mm",
)
observation['stem-width'] = [stem_width]
# # # Prediction and Display # # #
# Every value in ``observation`` is a one-element list, so this frame has
# exactly one row.
single_obs_df = pd.DataFrame(observation)
# One row in means one prediction out; take that single result.
current_prediction = predict_mushroom(single_obs_df)[0]
# Debug output: print() goes to the terminal running Streamlit, not the page.
print(f"model results: {current_prediction}")
print(observation)
# A prediction of 0 is shown as "not poisonous", anything else as poisonous.
# NOTE(review): assumes predict_mushroom encodes edible as 0 - confirm.
# (Streamlit renders emoji in markdown as-is.)
if current_prediction == 0:
    st.markdown("### 🍄🍄🍄 Mushroom is not poisonous")
else:
    st.markdown("### 🤢🤮💀 Mushroom is poisonous!")
# Repeat the disclaimer below the verdict. The noqa needs a colon, or flake8
# treats it as a blanket "# noqa"; two spaces before "#" satisfy E261.
st.markdown("""
> Note: this model was trained on a generated database.
Do not use it to classify real mushrooms!
""")  # noqa: W291