{-
    Kaya - My favourite toy language.
    Copyright (C) 2004, 2005 Edwin Brady

    This file is distributed under the terms of the GNU General
    Public Licence. See COPYING for licence.
-}

module Language(module AbsSyntax, module Language) where

----- Useful gadgets on the abstract syntax; also reexports syntax types.

import List
import Debug.Trace
import AbsSyntax

-- Merge somehow with 'Error' and tidy up name lookup functions.
-- | Outcome of resolving a name.
data Lookup a
    = Got a                  -- ^ exactly one candidate was found
    | Ambiguous Name [Name]  -- ^ the name looked up, and every candidate
    | UnknownName Name       -- ^ no candidate matched the name

-- | Lookup results render as human-readable messages (these end up in error
-- reports); the candidates of an ambiguous lookup are listed sorted.
instance Show a => Show (Lookup a) where
    show (Got a) = "Success: " ++ show a
    show (Ambiguous n xs) =
        "Ambiguous name: " ++ showuser n ++
        " (Could be " ++ commaSep (sort xs) ++ ")"
      where
        commaSep names = concat (intersperse ", " (map showuser names))
    show (UnknownName n) = "Unknown name: " ++ showuser n


-- | All bindings a (possibly namespace-qualified) name could refer to,
-- drawn from the current module and every other one.
--
-- A qualified name (NS ...) must equal a binding's name exactly.  An
-- unqualified name matches any binding whose name, once every namespace
-- qualifier is stripped, equals it.  Matches are returned in the reverse of
-- their order in the context.
--
-- The current-module argument is not used at present.  Preferring bindings
-- from the current module was considered, but it does not make sense in the
-- presence of ad-hoc overloading.
lookupname :: Name -> -- current module
              Name -> -- name to lookup (possibly decorated)
              [(Name,a)] -> [(Name,a)]
lookupname mod n gam = reverse (filter (matches . fst) gam)
  where
    matches x = case n of
                  NS _ _ -> n == x
                  _      -> n == bare x

    -- strip every namespace qualifier from a binding's name
    bare (NS _ inner) = bare inner
    bare other = other



-- | Resolve a name against the context, giving its fully qualified name and
-- its type.  The namespace may be omitted when that causes no ambiguity.
--
-- Only accessible bindings (public ones, or ones qualified with the current
-- module) are candidates.  The optional expected type is used to prefer
-- candidates whose argument types are compatible with it; if none are
-- compatible, all candidates are kept.  Fails with a message when the name
-- is unknown, refers only to private bindings, or is still ambiguous.
--
-- FIXME: This should return a Lookup structure, to be more informative.
ctxtlookup :: Monad m => Name -> -- Current module
                         Name -> Context ->
                         Maybe Type -> -- type information to help disambiguate
                         m (Name, Type)
ctxtlookup mod n ctx ty = decide (preferTyped (dedupe usable)) hidden
  where
    -- Every distinct candidate, split by whether it may be used from here.
    (usable, hidden) = partition accessible (nub (lookupname mod n ctx))

    decide [(x, (xty, _))] _ = return (x, xty)
    decide [] [] = fail $ "Unknown name " ++ showuser n
    decide [] priv = fail $ "Can't use private name: " ++
                            showuser n ++ " (" ++ showthings priv ++ ")"
    decide pub _ = fail $ "Ambiguous name: " ++ showuser n ++
                          " (Could be " ++
                          showthings pub ++ ")"

    -- Keep only the last candidate for each (name, type) pair.
    dedupe cands = foldr keepLast [] cands
    keepLast c@(x, (xty, _)) kept
        | (x, xty) `elem` map nameAndType kept = kept
        | otherwise = c : kept
    nameAndType (x, (xty, _)) = (x, xty)

    -- Candidates compatible with the expected type, or all of them if none
    -- is compatible (so the ambiguity still gets reported).
    preferTyped cands
        | null compatible = cands
        | otherwise = compatible
      where compatible = filter (fitsExpected ty) cands

    -- NOTE(review): zipWith means argument lists of different lengths are
    -- compared only on their common prefix.
    fitsExpected (Just (Fn _ args _)) (_, (Fn _ args2 _, _))
        = and (zipWith argsCompatible args args2)
    fitsExpected _ _ = True

    argsCompatible a b | a == b = True
    argsCompatible (TyVar _) _ = True
    argsCompatible _ (TyVar _) = True
    argsCompatible (Array a) (Array b) = argsCompatible a b
    argsCompatible (User t ts) (User u us)
        = t == u && and (zipWith argsCompatible ts us)
    argsCompatible (Fn _ _ _) (Fn _ _ _) = True
    argsCompatible _ _ = False

    -- A private name in the current module is accessible.
    accessible (NS nmod _, _) | nmod == mod = True
    accessible (_, (_, fopts)) = Public `elem` fopts

    showthings cs = showStrs (sort (map showsig cs))
    showsig (x, (xty, _)) = showuser x ++ "(" ++ showargs xty ++ ")"

    showargs (Fn _ args _) = showlist args
    showargs _ = ""

-- | Render each value with 'show', separated by ", ".
showlist :: Show a => [a] -> String
showlist xs = concat (intersperse ", " (map show xs))

-- | Join strings, putting ", " between consecutive ones.
showStrs :: [String] -> String
showStrs strs = concat (intersperse ", " strs)

-- | Render a context one binding per line, as "name :: type".
ctxtdump :: Context -> String
ctxtdump ctx = concatMap entry ctx
  where entry (n, (ty, _)) = showuser n ++ " :: " ++ show ty ++ "\n"


------------ Gadgets -------------

-- | @x `tlt` y@ holds when primitive type x is "smaller" than y, i.e. there
-- is a (meaningful?) injection from x to y.  Only the pairs listed below are
-- related; every other pair, including a type with itself, is not.
--
-- There'll be a better way if this table gets much bigger (perhaps a
-- PartialOrd class).
tlt :: PrimType -> PrimType -> Bool
tlt from to = case (from, to) of
    (Boolean,   Number)     -> True
    (Boolean,   StringType) -> True
    (Character, Number)     -> True
    (Character, StringType) -> True
    (Number,    RealNum)    -> True
    (Number,    StringType) -> True
    (RealNum,   StringType) -> True
    _                       -> False

-- | The larger of two types by 'tlt'.  The first type is returned unless
-- both are primitive and the first is strictly smaller than the second.
biggert :: Type -> Type -> Type
biggert t1 t2 = case (t1, t2) of
    (Prim a, Prim b) | a `tlt` b -> Prim b
    _                            -> t1

-- | Type-based suffix used for name mangling: an underscore followed by the
-- encoding from mangling'.
mangling :: Type -> String
mangling t = '_' : mangling' t

-- | Encode a type as a short string.  Function types encode only their
-- argument types; arrays are prefixed with 'a'; user types use their shown
-- name followed by their parameters; primitives use a single letter.  Any
-- other type encodes as the empty string.
mangling' :: Type -> String
mangling' t = case t of
    Fn _ args _ -> concatMap mangling' args
    Array arg   -> 'a' : mangling' arg
    User n args -> show n ++ concatMap mangling' args
    Prim p      -> primLetter p
    _           -> ""
  where
    primLetter Number     = "i"
    primLetter Character  = "c"
    primLetter Boolean    = "b"
    primLetter RealNum    = "f"
    primLetter StringType = "s"
    primLetter File       = "F"
    primLetter Pointer    = "p"
    primLetter Exception  = "e"
    primLetter Void       = "v"
    primLetter _          = ""

-- | Every type variable occurring in a type, left to right (for function
-- types, arguments before the result), duplicates included.
getTyVars :: Type -> [Name]
getTyVars ty = case ty of
    TyVar n    -> [n]
    User _ tys -> concatMap getTyVars tys
    Array t    -> getTyVars t
    Fn _ tys t -> concatMap getTyVars tys ++ getTyVars t
    _          -> []

-- | A C-level name.  Source names are suffixed with a type mangling so that
-- overloaded functions get distinct C names.
type Mangled = String

-- | Build a C name by appending the mangling string to the shown name.
cname :: Name -> String -> Mangled
cname n = (show n ++)

-- | Whether the first type may be used where the second is expected.  At
-- present only identical types are convertible.
convert :: Type -> Type -> Bool
convert from to = from == to

-- | Succeed when the first type converts to the second (see 'convert'),
-- otherwise fail with the supplied message.
checkConv :: Monad m => Type -> Type -> String -> m ()
checkConv x y err
    | convert x y = return ()
    | otherwise   = fail err

-- | Find the value bound to a name in an association list, failing with a
-- message when the name is absent.
getType :: Monad m => Name -> [(Name,b)] -> m b
getType n ctxt =
    maybe (fail ("Unknown name gettype " ++ show n)) return (lookup n ctxt)

-- | The distinct type variables of a type.  In function types the result
-- type is searched before the arguments; the first occurrence of each
-- variable determines its position.
getVars :: Type -> [Name]
getVars ty = nub (collect ty)
  where
    collect t = case t of
        Fn _ args res -> concatMap collect (res : args)
        Array e       -> collect e
        User _ args   -> concatMap collect args
        TyVar n       -> [n]
        _             -> []

-- | Number of arguments of a function type; 0 for any other type.
numargs :: Type -> Int
numargs ty = case ty of
    Fn _ args _ -> length args
    _           -> 0


-- | Give each distinct user-named type variable (see 'getUserVars') a fresh
-- name MN ("TV", i), so that independent variables stay independent.
-- Indices are allocated consecutively from the given counter; the counter
-- after the last allocated index is returned with the renamed type.
-- NOTE(review): 'rename' only covers the Type constructors listed below,
-- like the original; any other constructor would fail to match.
fudgevars :: Type -> Int -> (Type, Int)
fudgevars t next = (rename t, next + length vars)
  where
    vars = getUserVars t
    -- old variable name -> fresh name
    fresh = zip vars [MN ("TV", i) | i <- [next ..]]

    rename (Prim p)      = Prim p
    rename (Fn ns tys r) = Fn ns (map rename tys) (rename r)
    rename (Array e)     = Array (rename e)
    rename (User n tys)  = User n (map rename tys)
    rename (TyVar x)     = TyVar (maybe x id (lookup x fresh))
    rename UnknownType   = UnknownType

-- | The distinct type variables of a type whose names are user names (UN);
-- variables with any other kind of name are ignored.  Traversal order
-- matches 'getVars' (function result before arguments).
getUserVars :: Type -> [Name]
getUserVars ty = nub (collect ty)
  where
    collect t = case t of
        Fn _ args res  -> concatMap collect (res : args)
        Array e        -> collect e
        User _ args    -> concatMap collect args
        TyVar v@(UN _) -> [v]
        _              -> []

-- | Turn an assignment target back into the raw expression that reads the
-- same location (a variable, an indexed element, or a field).
lvaltoexp :: RAssign -> Raw
lvaltoexp lv = case lv of
    RAName f l n        -> RVar f l n
    RAIndex f l inner r -> RIndex f l (lvaltoexp inner) r
    RAField f l inner r -> RField f l (lvaltoexp inner) r

-- | Render a constant as source text.
-- Integers and reals use 'show'; booleans are written in lower case; the
-- NUL character is written as '\0'; other characters and strings use
-- Haskell's 'show' escaping; exceptions render as exception(message,code).
-- Real constants (produced by constant folding) previously had no case here
-- and crashed with a pattern-match failure.
-- NOTE(review): Haskell escapes (e.g. decimal '\200', '\DEL') are not all
-- valid C escapes; confirm what consumes this output.
showconst :: Const -> String
showconst (Num x) = show x
showconst (Re x) = show x
showconst (Ch '\0') = "'\\0'"
showconst (Ch c) = show c
showconst (Bo True) = "true"
showconst (Bo False) = "false"
showconst (Str str) = show str
showconst (Exc str i) = "exception("++show str++","++show i++")"
showconst (Empty) = error "Can't show an empty constant"


-- | Look a type name up in the type context.  The namespace may be omitted
-- when that causes no ambiguity; on success the fully qualified name is
-- returned together with the type's information.
typelookup :: Name -> -- Current module
              Name -> Types -> Lookup (Name, TypeInfo)
typelookup mod t ti =
    case dedupe (lookupname mod t ti) of
      [found] -> Got found
      []      -> UnknownName t
      several -> Ambiguous t (map fst several)
  where
    -- Keep only the last entry for each name.
    dedupe entries = foldr keepLast [] entries
    keepLast entry@(name, _) kept
        | name `elem` map fst kept = kept
        | otherwise = entry : kept

-- | Normalise a type: expand type synonyms, fully qualify the names of user
-- types, and check that user types and synonyms receive the right number of
-- parameters.
--
-- @f@ and @l@ (file and line) prefix every error message; @mod@ is the
-- current module, used to resolve names against the known types @ti@.
-- Names of unknown user types are kept, qualified with the current module.
normalise :: Monad m => String -> Int -> Name -> Types -> Type -> m Type
normalise f l mod ti t = tn [] t
  where
    -- "file:line:" prefix shared by every error message.
    loc = f ++ ":" ++ show l ++ ":"

    -- Normalise a type.  'u' lists the synonyms currently being expanded,
    -- so that cyclic synonyms can be detected.
    tn u (Fn ds ts r)
        = do ts' <- mapM (tn u) ts
             r' <- tn u r
             return $ Fn ds ts' r'
    tn u (Array e)
        = do e' <- tn u e
             return $ Array e'
    tn u (User n ts) =
        case typelookup mod n ti of
          UnknownName _ ->
              do ts' <- mapM (tn u) ts
                 return $ User (fixup mod n) ts'
          am@(Ambiguous _ _) -> fail $ loc ++ show am
          Got (fqn, x) -> applyTI u fqn ts x
    tn _ rest = return rest

    -- Qualify an unqualified name with the given module.
    fixup _ fqn@(NS _ _) = fqn
    fixup m n = NS m n

    -- Apply the type information found for the fully qualified name 'n' to
    -- the parameters 'ts'.
    applyTI u n ts (UserData params)
        | length ts < length params
            = fail $ loc ++ "Type " ++ showuser n ++ " has too few parameters"
        | length ts > length params
            = fail $ loc ++ "Type " ++ showuser n ++ " has too many parameters"
        | otherwise
            = do ts' <- mapM (tn u) ts
                 return $ User n ts'
    -- Replace the synonym by its body, with the synonym's parameters
    -- replaced by the corresponding arguments.
    applyTI u n ts (Syn params body)
        | length ts < length params
            = fail $ loc ++ "Type synonym " ++ showuser n ++
                     " has too few parameters"
        | length ts > length params
            = fail $ loc ++ "Type synonym " ++ showuser n ++
                     " has too many parameters"
        | n `elem` u
            = fail $ loc ++ "Cycle in type synonyms " ++
                     showStrs (map showuser u)
        | otherwise
            -- Normalise the arguments in the outer context first.  Otherwise
            -- a synonym used inside its own arguments (e.g. S<S<Int>>) would
            -- be met again while 'n' is on the stack and be reported as a
            -- cycle.  Only the body is expanded with 'n' marked in progress.
            = do ts' <- mapM (tn u) ts
                 st <- substty (zip params ts') body
                 tn (n:u) st
    -- The ":" after the line number was missing from this message.
    applyTI _ n _ Private
        = fail $ loc ++ "Can't use private type " ++ showuser n

    -- Substitute types for type variables according to 'tmap'.
    substty tmap (TyVar n)
        = case lookup n tmap of
            Nothing -> fail $ "Shouldn't happen" ++ show tmap
            Just ty -> return ty
    substty tmap (Array e)
        = do e' <- substty tmap e
             return $ Array e'
    substty tmap (Fn ds args r)
        = do args' <- mapM (substty tmap) args
             r' <- substty tmap r
             return $ Fn ds args' r'
    substty tmap (User n ts)
        = do ts' <- mapM (substty tmap) ts
             return $ User n ts'
    substty _ rest = return rest

-- | Fold constant arithmetic in a raw term.  A binary or unary operator
-- applied directly to integer or real literals is replaced by the literal
-- result.  Only the outermost node is examined; a term that cannot be
-- folded is returned unchanged.
-- TODO/FIXME: Check bounds?
cfold :: Raw -> Raw
cfold r = case r of
    RInfix f l op (RConst _ _ (Num x)) (RConst _ _ (Num y)) ->
        orOriginal f l (foldint op x y)
    RInfix f l op (RConst _ _ (Re x)) (RConst _ _ (Re y)) ->
        orOriginal f l (foldreal op x y)
    RUnary f l op (RConst _ _ (Num x)) -> orOriginal f l (foldunint op x)
    RUnary f l op (RConst _ _ (Re x))  -> orOriginal f l (foldunreal op x)
    _ -> r
  where
    -- wrap a folded value as a constant at the term's position, or keep r
    orOriginal f l folded = maybe r (RConst f l) folded

-- | The constant held by a literal term.  Calling this on any other kind of
-- term is a programming error and reports one explicitly.
getConst :: Raw -> Const
getConst (RConst _ _ c) = c
getConst _ = error "Language.getConst: term is not a constant"

-- | Evaluate a binary operator on two integer constants at compile time.
-- Arithmetic produces a Num and comparisons produce a Bo.  Returns Nothing
-- (leave the term unfolded) for other operators, and for division or
-- modulo by zero.
-- NOTE(review): 'div' and 'mod' round toward negative infinity, whereas C's
-- '/' and '%' truncate toward zero.  If run-time integer division follows C,
-- folding negative operands gives different answers; confirm, and consider
-- 'quot'/'rem'.
foldint :: Op -> Int -> Int -> Maybe Const
foldint Plus x y = Just $ Num (x+y)
foldint Minus x y = Just $ Num (x-y)
foldint Times x y = Just $ Num (x*y)
-- TODO: Should be compile time error
foldint Divide _ 0 = Nothing
foldint Divide x y = Just $ Num (x `div` y)
-- 'mod' by zero throws, which would crash the compiler; like Divide, leave
-- it unfolded.
foldint Modulo _ 0 = Nothing
foldint Modulo x y = Just $ Num (x `mod` y)
foldint Power x y = Just $ Num (floor ((fromIntegral x)**(fromIntegral y)))
foldint Equal x y = Just $ Bo (x==y)
foldint NEqual x y = Just $ Bo (x/=y)
foldint OpLT x y = Just $ Bo (x<y)
foldint OpGT x y = Just $ Bo (x>y)
foldint OpLE x y = Just $ Bo (x<=y)
foldint OpGE x y = Just $ Bo (x>=y)
foldint _ _ _ = Nothing

-- | Evaluate a binary operator on two real constants at compile time.
-- Arithmetic produces an Re and comparisons produce a Bo.  Returns Nothing
-- (leave the term unfolded) for division by zero and for other operators.
foldreal :: Op -> Double -> Double -> Maybe Const
foldreal op x y = case op of
    Plus   -> real (x + y)
    Minus  -> real (x - y)
    Times  -> real (x * y)
    -- TODO: Should be compile time error
    Divide | y == 0    -> Nothing
           | otherwise -> real (x / y)
    Power  -> real (x ** y)
    Equal  -> bool (x == y)
    NEqual -> bool (x /= y)
    OpLT   -> bool (x < y)
    OpGT   -> bool (x > y)
    OpLE   -> bool (x <= y)
    OpGE   -> bool (x >= y)
    _      -> Nothing
  where
    real v = Just (Re v)
    bool b = Just (Bo b)

-- | Fold a unary operator on an integer constant; only negation is folded.
foldunint :: UnOp -> Int -> Maybe Const
foldunint op x = case op of
    Neg -> Just (Num (negate x))
    _   -> Nothing

-- | Fold a unary operator on a real constant; only negation is folded.
foldunreal :: UnOp -> Double -> Maybe Const
foldunreal op x = case op of
    Neg -> Just (Re (negate x))
    _   -> Nothing

-- | Rebuild an expression by applying 'f' to each of its immediate
-- sub-expressions.  This is not recursive: 'f' decides whether to descend
-- further.  A Metavar node is replaced instead by 'mf' applied to the
-- Metavar's three fields.  Inside assignment targets 'f' is applied to index
-- expressions; inside case alternatives it is applied to the bodies (and to
-- an Alt's expression list).  Nodes with no sub-expressions are returned
-- unchanged.  (It is not clear how generally useful this is, but the
-- optimiser uses it.)
mapsubexpr f mf expr = case expr of
    Metavar fl l x      -> mf fl l x
    Lambda ivs args e   -> Lambda ivs args (f e)
    Closure args t e    -> Closure args t (f e)
    Bind n ty e1 e2     -> Bind n ty (f e1) (f e2)
    Declare fn l n t e  -> Declare fn l n t (f e)
    Return e            -> Return (f e)
    Assign a e          -> Assign (onLval a) (f e)
    AssignOp op a e     -> AssignOp op (onLval a) (f e)
    Seq a b             -> Seq (f a) (f b)
    Apply fn xs         -> Apply (f fn) (map f xs)
    Partial fn xs i     -> Partial (f fn) (map f xs) i
    Foreign ty n es     -> Foreign ty n [ (f (fst p), snd p) | p <- es ]
    While e b           -> While (f e) (f b)
    DoWhile e b         -> DoWhile (f e) (f b)
    For i nm j a e1 e2  -> For i nm j (onLval a) (f e1) (f e2)
    TryCatch t e fl fin -> TryCatch (f t) (f e) (f fl) (f fin)
    Throw e             -> Throw (f e)
    Except e i          -> Except (f e) (f i)
    Infix op x y        -> Infix op (f x) (f y)
    CmpStr op x y       -> CmpStr op (f x) (f y)
    CmpExcept op x y    -> CmpExcept op (f x) (f y)
    RealInfix op x y    -> RealInfix op (f x) (f y)
    Append x y          -> Append (f x) (f y)
    Unary op x          -> Unary op (f x)
    RealUnary op x      -> RealUnary op (f x)
    Coerce t1 t2 x      -> Coerce t1 t2 (f x)
    Case e alts         -> Case (f e) (map onAlt alts)
    If a t e            -> If (f a) (f t) (f e)
    Index a b           -> Index (f a) (f b)
    Field e n i j       -> Field (f e) n i j
    ArrayInit xs        -> ArrayInit (map f xs)
    Annotation a e      -> Annotation a (f e)
    _                   -> expr
  where
    -- apply 'f' inside an assignment target
    onLval (AIndex a e)     = AIndex (onLval a) (f e)
    onLval (AField a n i j) = AField (onLval a) n i j
    onLval other            = other

    -- apply 'f' inside one case alternative
    onAlt (Alt i j es e)    = Alt i j (map f es) (f e)
    onAlt (Default e)       = Default (f e)
    onAlt (ConstAlt pt c e) = ConstAlt pt c (f e)

-- | Summarise an expression's immediate sub-expressions.  'f' is applied to
-- each one and the results are combined with 'com'; where a node has more
-- than two contributions they are nested to the left.  A list of
-- sub-expressions contributes f x1 `com` (f x2 `com` (... `com` def)).
-- 'def' is the result for nodes without sub-expressions (Metavar included).
foldsubexpr :: (Expr n -> a) -> (a -> a -> a) -> a -> Expr n -> a
foldsubexpr f com def expr = case expr of
    Lambda _ _ e        -> f e
    Closure _ _ e       -> f e
    Bind _ _ e1 e2      -> f e1 `com` f e2
    Declare _ _ _ _ e   -> f e
    Return e            -> f e
    Assign a e          -> onLval a `com` f e
    AssignOp _ a e      -> onLval a `com` f e
    Seq a b             -> f a `com` f b
    Apply fn xs         -> f fn `com` overList xs
    Partial fn xs _     -> f fn `com` overList xs
    Foreign _ _ es      -> overList (map fst es)
    While e b           -> f e `com` f b
    DoWhile e b         -> f e `com` f b
    For _ _ _ a e1 e2   -> (onLval a `com` f e1) `com` f e2
    TryCatch t e fl fin -> ((f t `com` f e) `com` f fl) `com` f fin
    Throw e             -> f e
    Except e i          -> f e `com` f i
    Infix _ x y         -> f x `com` f y
    CmpStr _ x y        -> f x `com` f y
    CmpExcept _ x y     -> f x `com` f y
    RealInfix _ x y     -> f x `com` f y
    Append x y          -> f x `com` f y
    Unary _ x           -> f x
    RealUnary _ x       -> f x
    Coerce _ _ x        -> f x
    Case e alts         -> f e `com` overAlts alts
    If a t e            -> (f a `com` f t) `com` f e
    Index a b           -> f a `com` f b
    Field e _ _ _       -> f e
    ArrayInit xs        -> overList xs
    Annotation _ e      -> f e
    _                   -> def
  where
    -- combine 'f' over a list, right-nested and ending in 'def'
    overList xs = foldr (\x acc -> f x `com` acc) def xs

    -- contributions from an assignment target's index expressions
    onLval (AIndex a e)     = onLval a `com` f e
    onLval (AField a _ _ _) = onLval a
    onLval _                = def

    -- contributions from case alternatives, right-nested and ending in 'def'
    overAlts alts = foldr step def alts
    step alt acc = case alt of
        Alt _ _ es e   -> (overList es `com` f e) `com` acc
        Default e      -> f e `com` acc
        ConstAlt _ _ e -> f e `com` acc

