]> git.angelumana.com Git - wu-api/.git/commitdiff
Add loading and making predictions with model, reporting prediction results main origin/HEAD origin/main
authoribidyouadu <angel.d.umana@gmail.com>
Wed, 24 Jul 2024 04:39:18 +0000 (23:39 -0500)
committeribidyouadu <angel.d.umana@gmail.com>
Wed, 24 Jul 2024 04:39:18 +0000 (23:39 -0500)
app/data.py
app/inference.py
app/main.py
app/model.py
app/params.py [new file with mode: 0644]
app/templates/result.html

index 464090415c47109523e91779d4f40e19495c9cf1..5df7cce596b38dfdb9e195fab9c008b76147b92f 100644 (file)
@@ -1 +1,49 @@
-# TODO
+from PIL import Image
+from io import BytesIO
+import base64
+import tensorflow as tf
+from tensorflow import keras
+from sklearn.preprocessing import LabelEncoder
+
+def load_image(contents):
+    """
+    Given a base64 decoded image from a POST request to /result, return a tensorflow tensor
+    representation of the image data.
+
+    Parameters
+    ----------
+    contents : bytes
+        Output from UploadFile.read()
+    
+    Returns
+    -------
+    image : tf.Tensor
+        Tensor representation of the image data
+    """
+    bytes_image = BytesIO(contents)
+
+    # Load image as batch tensor
+    # https://www.tensorflow.org/api_docs/python/tf/keras/utils/load_img
+    pil_image = keras.utils.load_img(bytes_image)
+    array_image = keras.utils.img_to_array(pil_image)
+    image = tf.convert_to_tensor(array_image)
+
+    return image
+
+def preprocess_image(image):
+    """
+    Normalize pixel values, reshape image if necessary, and convert to greyscale.
+    """
+    # Normalize pixel values
+    image = image/255
+    image = tf.cast(image,  tf.float32)
+
+    # Reshape
+    if image.shape != (400, 400, 3):
+        image = tf.image.resize(image, [400, 400])
+
+    # Convert to greyscale
+    image = tf.tensordot(image, tf.constant([0.299, 0.587, 0.114]), axes=[[2], [0]])
+    # grey_image = tf.expand_dims(grey_image, -1)
+
+    return image
\ No newline at end of file
index 464090415c47109523e91779d4f40e19495c9cf1..3b6fa4d31b22cd2fd94cef41447465e3a04eaa91 100644 (file)
@@ -1 +1,47 @@
-# TODO
+from datasets import load_dataset
+from tensorflow import keras
+import tensorflow as tf
+import numpy as np
+from sklearn.preprocessing import LabelEncoder
+from params import DATASET_REPO_ID, label_names_to_english
+
+
+def get_label_encoder():
+    """
+    Create LabelEncoder object to translate integers to text.
+    """
+    dataset_info = load_dataset(DATASET_REPO_ID, split="train", streaming=True)._info
+    labels = dataset_info.features['label'].names
+    labels = [label_names_to_english[l] for l in labels]
+    le = LabelEncoder()
+    le.fit(labels)
+
+    return le
+
+def make_prediction(model, input):
+    """
+    Make prediction for image and return the label and associated probability.
+
+    Parameters
+    ----------
+    model : keras.Sequential
+        The cloud identifier model
+    input : tf.Tensor
+        Output from preproceess_data()
+
+    Returns
+    -------
+    predicted_label_name : str
+        Name of the most likely label
+    predicted_proba : np.float64
+        Probability of the most likely label
+    """
+    batch = tf.expand_dims(input, axis=0)
+    probabilities = model.predict(batch)
+    predicted = np.argmax(probabilities)
+    predicted_proba = round(np.max(probabilities) * 100, 1)
+
+    le = get_label_encoder()
+    predicted_label_name = le.inverse_transform(predicted.reshape(1,))[0]
+
+    return predicted_label_name, predicted_proba
\ No newline at end of file
index 40884936d4a43ea3dcdcf3f0e905356f33f5e868..0d3e7b4dc6fe7a383264bbae5bbfc85c69dea642 100644 (file)
@@ -4,11 +4,14 @@ from fastapi.templating import Jinja2Templates
 from tempfile import NamedTemporaryFile
 import os
 import base64
 from tempfile import NamedTemporaryFile
 import os
 import base64
+from data import load_image, preprocess_image
+from model import get_model
+from inference import make_prediction, get_label_encoder
 
 app = FastAPI()
 templates = Jinja2Templates("templates")
 
 
 app = FastAPI()
 templates = Jinja2Templates("templates")
 
-@app.get("/", response_class=HTMLResponse)
+@app.get("/", response_class = HTMLResponse)
 def index(request: Request):
     context = {"request": request}
     response = templates.TemplateResponse("index.html", context)
 def index(request: Request):
     context = {"request": request}
     response = templates.TemplateResponse("index.html", context)
@@ -19,8 +22,17 @@ def index(request: Request):
 def show_image(request: Request, background_tasks: BackgroundTasks, input_image: UploadFile = File(...)):
     contents = input_image.file.read()
     encoded_image = base64.b64encode(contents).decode("utf-8")
 def show_image(request: Request, background_tasks: BackgroundTasks, input_image: UploadFile = File(...)):
     contents = input_image.file.read()
     encoded_image = base64.b64encode(contents).decode("utf-8")
-    
-    context = {"request": request, "image": encoded_image}
+    raw_image = load_image(contents)
+    image = preprocess_image(raw_image)
+    model = get_model()
+    predicted_label, predicted_probability = make_prediction(model, image)
+
+    context = {
+        "request": request,
+        "image": encoded_image,
+        "predicted_label": predicted_label,
+        "predicted_probability": f"{predicted_probability}%"
+    }
     response = templates.TemplateResponse("result.html", context)
     return response
 
     response = templates.TemplateResponse("result.html", context)
     return response
 
index 464090415c47109523e91779d4f40e19495c9cf1..8104469b56e65e53ea8aa1291e2f8b1b4ad5ac83 100644 (file)
@@ -1 +1,11 @@
-# TODO
+from huggingface_hub import hf_hub_download
+from tensorflow import keras
+from params import MODEL_REPO_ID, MODEL_FILENAME
+def get_model():
+    """
+    Download and load the model from huggingface
+    """
+    model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)
+    model = keras.models.load_model(model_path)
+
+    return model
\ No newline at end of file
diff --git a/app/params.py b/app/params.py
new file mode 100644 (file)
index 0000000..a78bce1
--- /dev/null
@@ -0,0 +1,16 @@
+DATASET_REPO_ID = "aduuuuuu/CCSN"
+MODEL_REPO_ID = "aduuuuuu/wu"
+MODEL_FILENAME = "model.h5"
+label_names_to_english = {
+    "Ac": "altocumulus",
+    "As": "altostratus",
+    "Cb": "cumulonimbus",
+    "Cc": "cumulus",
+    "Ci": "cirrus",
+    "Cs": "cirrostratuss",
+    "Ct": "contrail",
+    "Cu": "cumulus",
+    "Ns": "nimbostratus",
+    "Sc": "stratocumulus",
+    "St": "stratus"
+}
\ No newline at end of file
index 635ff3093bcf533ee4cb2f05ab2dbcdd3be61e30..dbe2148c4b453e470ffbc4d57a5d5547682855d2 100644 (file)
@@ -1,6 +1,6 @@
 <!DOCTYPE html>
 <html><body>
 <!DOCTYPE html>
 <html><body>
-    <p>Wow what a lovely pic !!!</p>
+    <p>Wu say this is a {{ predicted_label }} with probability {{ predicted_probability }}</p>
     <img src="data:image/jpeg;base64,{{image | safe}}" />
     <a href="/">Do it again, do it again!</a>
 </body></html>
\ No newline at end of file
     <img src="data:image/jpeg;base64,{{image | safe}}" />
     <a href="/">Do it again, do it again!</a>
 </body></html>
\ No newline at end of file