diff --git a/distribusi.go b/distribusi.go
index 6f42ea6..8ac3614 100644
--- a/distribusi.go
+++ b/distribusi.go
@@ -564,24 +564,16 @@ func getCaption(c *cli.Context, fpath string) (string, error) {
}
// parseMtype parses a mimetype string to simplify programmatic type lookups.
-func parseMtype(mtype string) (string, string) {
+func parseMtype(mtype string) (string, string, error) {
+ if !strings.Contains(mtype, "/") {
+ return "", "", fmt.Errorf("unable to parse %s", mtype)
+ }
+
stripCharset := strings.Split(mtype, ";")
splitTypes := strings.Split(stripCharset[0], "/")
-
ftype, stype := splitTypes[0], splitTypes[1]
- return ftype, stype
-}
-
-// sliceContains checks if an element is present in a list.
-func sliceContains(items []string, target string) bool {
- for _, item := range items {
- if item == target {
- return true
- }
- }
-
- return false
+ return ftype, stype, nil
}
// trimFinalNewline trims newlines from the end of bytes just read from files.
@@ -604,7 +596,10 @@ func mkHref(c *cli.Context, fpath string, mtype string) (bool, string, error) {
var unknown bool
fname := filepath.Base(fpath)
- ftype, stype := parseMtype(mtype)
+ ftype, stype, err := parseMtype(mtype)
+ if err != nil {
+ return unknown, href, err
+ }
if ftype == "text" {
fcontents, err := os.ReadFile(fpath)
@@ -680,7 +675,10 @@ func mkDiv(c *cli.Context, mtype string, href, fname string, unknown bool) (stri
filename := fmt.Sprintf("%s", fname)
- ftype, stype := parseMtype(mtype)
+ ftype, stype, err := parseMtype(mtype)
+ if err != nil {
+ return div, err
+ }
if ftype == "text" {
divTemplate = "
%s%s
"
diff --git a/distribusi_test.go b/distribusi_test.go
new file mode 100644
index 0000000..b44ca97
--- /dev/null
+++ b/distribusi_test.go
@@ -0,0 +1,65 @@
+package main
+
+import (
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+)
+
+func TestTrimFinalNewline(t *testing.T) {
+ trimmed := trimFinalNewline([]byte("foo\n"))
+ if trimmed != "foo" {
+ t.Fatalf("failed to trimmed new line from 'foo\\n'")
+ }
+}
+
+func TestTrimAllNewlines(t *testing.T) {
+ trimmed := trimAllNewlines("\nfoo\n")
+ if trimmed != "foo" {
+ t.Fatalf("failed to trimmed new line from '\\nfoo\\n'")
+ }
+}
+
+func TestParseMtype(t *testing.T) {
+ mtype := "os/directory"
+ ftype, stype, err := parseMtype(mtype)
+ if ftype != "os" || stype != "directory" {
+ t.Fatalf("unable to parse %s", mtype)
+ }
+
+ mtype = "image/gif; charset=utf-8"
+ ftype, stype, err = parseMtype(mtype)
+ if err != nil {
+ t.Fatalf("failed to parse %s", mtype)
+ }
+ if ftype != "image" || stype != "gif" {
+ t.Fatalf("unable to parse %s", mtype)
+ }
+
+ mtype = "notamimetype"
+ ftype, stype, err = parseMtype(mtype)
+ if err == nil {
+ t.Fatalf("failed to error out correctly parsing %s", mtype)
+ }
+}
+
+func TestGetLogFile(t *testing.T) {
+ f, err := getLogFile()
+ if err != nil {
+ t.Fatalf("failed to create log file, saw: %s", err)
+ }
+
+ if !strings.Contains(f.Name(), "distribusi-go") {
+ t.Fatalf("log file named incorrectly: %s", f.Name())
+ }
+
+ absPath, err := filepath.Abs(f.Name())
+ if err != nil {
+ t.Fatalf("failed to read absoluate path of %s", f.Name())
+ }
+
+ if err := os.Remove(absPath); err != nil {
+ t.Fatalf("unable to remove %s", absPath)
+ }
+}